{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "import json\n",
    "from nltk.stem import WordNetLemmatizer\n",
    "wordnet_lemmatizer = WordNetLemmatizer()\n",
    "\n",
    "TOK = 'question_tok_concol'\n",
    "TYPE = 'question_type_concol_list'\n",
    "HEADER = 'header_tok'\n",
    "# randomly chosen NONE string\n",
    "NONE = \"te8r2ed\" \n",
    "DEFAULT_TABLENAME = \"DEFAULT_NAME\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#dpath = \"/data/projects/nl2sql/models/datasets/data_wikisql_augment/train.json\"\n",
    "#dpath = \"/data/projects/nl2sql/datasets/data/train.json\"\n",
    "dpath = \"/data/projects/nl2sql/datasets/data_radn_split/test_radn.json\"\n",
    "tpath = \"/data/projects/nl2sql/datasets/data/tables.json\"\n",
    "\n",
    "with open(dpath) as f:\n",
    "    data = json.load(f)\n",
    "with open(tpath) as f:\n",
    "    tables = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "wikisql_path = \"/data/projects/wikitable_QA/SQLNet/data/train_final_concol.jsonl\"\n",
    "wikitables_path = \"/data/projects/wikitable_QA/SQLNet/data/train_tok.tables.jsonl\"\n",
    "\n",
    "with open(wikisql_path) as f:\n",
    "    wiki_cont = [json.loads(l.strip()) for l in f]\n",
    "    \n",
    "wiki_dict = {}\n",
    "for entry in wiki_cont:\n",
    "    wiki_dict[entry['question'].lower()] = {\n",
    "        TYPE: entry[TYPE],\n",
    "        TOK: entry[TOK]\n",
    "    }\n",
    "\n",
    "with open(wikitables_path) as f:\n",
    "    wikitables = [json.loads(l.strip()) for l in f]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "tables_dict = {}\n",
    "for table in tables:\n",
    "    tables_dict[table['db_id']] = table\n",
    "\n",
    "wikitables_dict = {}\n",
    "for table in wikitables:\n",
    "    wikitables_dict[table['id']] = table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def isNumber(val):\n",
    "    try:\n",
    "        int(val)\n",
    "        return True\n",
    "    except:\n",
    "        return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def group_table(toks, idx, num_toks, table_names):\n",
    "    table_toks = [name.split() for name in table_names]\n",
    "    for endIdx in reversed(range(idx+1, num_toks+1)):\n",
    "        sub_toks = toks[idx: endIdx]\n",
    "        if sub_toks in table_toks:\n",
    "            return endIdx, sub_toks\n",
    "    return idx, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def get_wikitable_name(table):\n",
    "    if 'page_title' in table:\n",
    "        return table['page_title'].lower().split()\n",
    "    \n",
    "    if 'section_title' in table:\n",
    "        return table['section_title'].lower().split()\n",
    "    \n",
    "    if 'caption' in table:\n",
    "        return table['caption'].lower().split()\n",
    "    \n",
    "    return [DEFAULT_TABLENAME]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def parse_processed_wiki(entry):\n",
    "    table = wikitables_dict[entry['db_id']]\n",
    "    parsed_wikisql = wiki_dict[entry['question'].lower()]\n",
    "    entry.update(parsed_wikisql)\n",
    "    headers = table[HEADER]\n",
    "    table_name = get_wikitable_name(table)\n",
    "    for i in range(len(entry[TYPE])):\n",
    "        if entry[TYPE][i] == ['column']:  # 'column' to table name\n",
    "            entry[TYPE][i] = table_name\n",
    "        elif entry[TYPE][i] in headers:   # value to [column name, table name]\n",
    "            entry[TYPE][i] += table_name\n",
    "    return entry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def group_header(toks, idx, num_toks, header_toks):\n",
    "    for endIdx in reversed(range(idx+1, num_toks+1)):\n",
    "        sub_toks = toks[idx: endIdx]\n",
    "        if sub_toks in header_toks:\n",
    "            return endIdx, sub_toks\n",
    "    return idx, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def group_val(toks, idx, num_toks, val2col):\n",
    "    if toks[idx] in val2col:\n",
    "        return idx + 1, val2col[toks[idx]]\n",
    "    return idx, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "res = []\n",
    "remain = []\n",
    "for entry in data:\n",
    "    # entry is processed before\n",
    "    if entry['question'].lower() in wiki_dict:\n",
    "        res.append(parse_processed_wiki(entry))\n",
    "    else:\n",
    "        remain.append(entry)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for entry in remain:\n",
    "    question_toks = [tok.lower() for tok in entry['question_toks']]\n",
    "    question_toks_lem = [wordnet_lemmatizer.lemmatize(t) for t in question_toks]\n",
    "    col2table = {}\n",
    "    if entry['db_id'] in wikitables_dict:\n",
    "        table = wikitables_dict[entry['db_id']]\n",
    "        header_toks = table['header_tok']\n",
    "        table_name = get_wikitable_name(table)\n",
    "        for col in header_toks:\n",
    "            col2table[\" \".join(col)] = table_name\n",
    "    else:\n",
    "        table = tables_dict[entry['db_id']]\n",
    "        header_toks = []\n",
    "        for col in table['column_names']:\n",
    "            this_header_tok = col[1].split()\n",
    "            header_toks.append(this_header_tok)\n",
    "            if col[0] < 0: # \"*\"\n",
    "                col2table[\" \".join(this_header_tok)] = [\"all\"]\n",
    "            else:\n",
    "                col2table[\" \".join(this_header_tok)] = \\\n",
    "                    table['table_names'][col[0]].split()\n",
    "    table_names = set([\" \".join(vs) for vs in col2table.values()])\n",
    "    # build val2col dictionary\n",
    "    val2col = {}\n",
    "    col_stacks = []\n",
    "    for tok1, tok2 in zip(entry['query_toks_no_value'], entry['query_toks']):\n",
    "        if tok1.split() in header_toks:\n",
    "            col_stacks.append(tok1.split())\n",
    "        elif tok1 == 'value' and len(col_stacks) > 0:\n",
    "            val2col[tok2] = col_stacks[-1]\n",
    "\n",
    "    idx = 0\n",
    "    num_toks = len(question_toks)\n",
    "    tok_concol = []\n",
    "    type_concol = []\n",
    "    \n",
    "    while idx < num_toks:\n",
    "        end_idx, tname = group_table(question_toks_lem, idx, num_toks, table_names)\n",
    "        if tname:\n",
    "            tok_concol.append(question_toks[idx: end_idx])\n",
    "            type_concol.append([\"table\"])\n",
    "            idx = end_idx\n",
    "            continue\n",
    "            \n",
    "        end_idx, header = group_header(question_toks_lem, idx, num_toks, header_toks)\n",
    "        if header:\n",
    "            tok_concol.append(question_toks[idx: end_idx])\n",
    "            type_concol.append(col2table[\" \".join(header)])\n",
    "            idx = end_idx\n",
    "            continue\n",
    "    \n",
    "        end_idx, col = group_val(question_toks_lem, idx, num_toks, val2col)\n",
    "        if col:\n",
    "            tok_concol.append([question_toks[idx - 1]])\n",
    "            if isNumber(question_toks[idx - 1]):\n",
    "                type_concol.append(col + col2table[\" \".join(col)] + [\"number\"])\n",
    "            else:\n",
    "                type_concol.append(col + col2table[\" \".join(col)])\n",
    "            idx = end_idx\n",
    "            continue\n",
    "    \n",
    "        tok_concol.append([question_toks[idx]])\n",
    "        type_concol.append([NONE])\n",
    "        idx += 1\n",
    "    \n",
    "    entry[TOK] = tok_concol\n",
    "    entry[TYPE] = type_concol\n",
    "    \n",
    "    res.append(entry)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with open('test_radn_type.json', 'w') as f:\n",
    "    json.dump(obj=res, fp=f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
